package server

import (
	"bytes"
	"encoding/binary"
	"epg/conf"
	"errors"
	"fmt"
	"net"
	"sync"
	"time"

	dmysql "github.com/go-sql-driver/mysql"
	"github.com/pingcap/tidb/mysql"

	"github.com/zeast/logs"
)

// Conn is the abstraction over one backend MySQL connection. It embeds
// sync.Locker, so every implementation exposes Lock/Unlock in addition to
// the protocol operations below.
type Conn interface {
	sync.Locker
	// Close releases the connection and its resources.
	Close() error
	// GetStatus returns the 16-bit status value currently recorded for the
	// connection.
	GetStatus() uint16
	// UseDB makes name the current database and collate the active
	// character set of the connection.
	UseDB(name string, collate string) (err error)
	// Exec sends a raw command (command byte followed by its argument) and
	// returns the first response packet.
	Exec(data []byte) ([]byte, error)
	// Prepare sends query as a prepared-statement request and returns the
	// raw response packets together with the statement handle.
	Prepare(query string) ([][]byte, *mysqlStmt, error)
	// Query sends a raw command (command byte followed by its argument) and
	// returns every packet of the response.
	Query(data []byte) ([][]byte, error)
	// WriteCommandPacketUint32 sends a command byte with a 32-bit argument.
	// NOTE(review): the implementation is outside this section — confirm
	// the exact semantics there.
	WriteCommandPacketUint32(byte, uint32) error


}

// Protocol constants.
//
// MinProtocolVersion is the lowest handshake protocol version that
// readInitPacket accepts; older servers are rejected.
//
// MaxPacketSize is 2^24-1. NOTE(review): presumably the largest payload a
// single MySQL packet can carry; it is not referenced in this section, so
// confirm against the packet I/O code.
//
// TimeFmt is a Go time layout with date, time and up to six optional
// fractional-second digits.
const (
	MinProtocolVersion byte = 10
	MaxPacketSize           = 1<<24 - 1
	TimeFmt                 = "2006-01-02 15:04:05.999999"
)

// CharsetNames maps a one-byte numeric identifier to a MySQL character set
// name.
// NOTE(review): the keys look like the ID of each character set's default
// collation (e.g. 33 for utf8, 45 for utf8mb4, 63 for binary) — confirm
// before relying on that for non-default collations.
var CharsetNames = map[uint8]string{
	1:  "big5",
	3:  "dec8",
	4:  "cp850",
	6:  "hp8",
	7:  "koi8r",
	8:  "latin1",
	9:  "latin2",
	10: "swe7",
	11: "ascii",
	12: "ujis",
	13: "sjis",
	16: "hebrew",
	18: "tis620",
	19: "euckr",
	22: "koi8u",
	24: "gb2312",
	25: "greek",
	26: "cp1250",
	28: "gbk",
	30: "latin5",
	32: "armscii8",
	33: "utf8",
	35: "ucs2",
	36: "cp866",
	37: "keybcs2",
	38: "macce",
	39: "macroman",
	40: "cp852",
	41: "latin7",
	45: "utf8mb4",
	51: "cp1251",
	54: "utf16",
	56: "utf16le",
	57: "cp1256",
	59: "cp1257",
	60: "utf32",
	63: "binary",
	92: "geostd8",
	95: "cp932",
	97: "eucjpms",
}

// mysqlConn is one connection to a backend MySQL server. The embedded
// sync.Mutex serialises the command methods (Prepare, Exec, Query, ping,
// Close, GetStatus).
//
// Fields whose role is visible in this section:
//   - io: packet reader/writer used for every protocol exchange.
//   - buf: source of the write buffer (see writeAuthPacket); released and
//     set to nil by Close.
//   - netConn: the underlying network connection; nil once Close has run.
//   - dbname, collation: the database and character set last applied
//     successfully by UseDB.
//   - cfg: node configuration; supplies User and Passwd for authentication.
//   - flags: lower 16 bits of the server capability flags from the handshake.
//   - status: status value taken from the most recent 5-byte EOF packet.
//   - pushedAt: timestamp that expired compares against the current time.
//   - lastError: error result of the most recent Prepare/Exec/Query/ping.
//
// The remaining fields (affectedRows, insertID, maxPacketAllowed,
// maxWriteSize, writeTimeout, sequence, strict) are not used in this section.
type mysqlConn struct {
	sync.Mutex
	io               *packetIO
	buf              *bufPool
	netConn          net.Conn
	affectedRows     uint64
	insertID         uint64
	dbname           string
	cfg              *conf.NodeConf
	maxPacketAllowed int
	maxWriteSize     int
	writeTimeout     time.Duration
	collation        string
	flags            uint32
	status           uint16
	sequence         uint8
	strict           bool
	pushedAt         time.Time
	lastError        error
}

// Close closes the underlying network connection and releases the
// connection's buffer via Put.
//
// Close is idempotent: once netConn and buf have been cleared, further
// calls do nothing and return nil. The error reported by the network
// connection's own Close, if any, is returned to the caller.
func (mc *mysqlConn) Close() error {
	mc.Lock()
	defer mc.Unlock()

	var err error
	if mc.netConn != nil {
		err = mc.netConn.Close()
		mc.netConn = nil
	}
	// Guard the release: buf is nil after the first Close, and calling Put
	// on it again would operate on a nil *bufPool.
	if mc.buf != nil {
		mc.buf.Put()
		mc.buf = nil
	}
	return err
}

// GetStatus returns the status value currently stored on the connection,
// read while holding the connection lock.
func (mc *mysqlConn) GetStatus() uint16 {
	mc.Lock()
	current := mc.status
	mc.Unlock()
	return current
}

// Begin opens a transaction by executing "BEGIN" on the connection and
// returns a Tx bound to it. It fails with mysql.ErrBadConn when the
// connection has no network connection, or with the error from Exec.
func (mc *mysqlConn) Begin() (Tx, error) {
	if mc.netConn == nil {
		return nil, mysql.ErrBadConn
	}
	if _, err := mc.Exec([]byte("BEGIN")); err != nil {
		return nil, err
	}
	return &mysqlTx{mc}, nil
}

// Prepare sends query as a ComStmtPrepare command and collects the raw
// response packets.
//
// data always starts with the first response packet; when the statement has
// parameters and/or result columns, the packets of each definition block up
// to and including its EOF packet are appended. stmt is the statement handle
// populated by readPrepareResultPacket. The resulting error is also stored
// in mc.lastError.
//
// On failure the return values differ by stage: nothing is returned when the
// connection is closed or the command cannot be written; data and stmt are
// both returned when the first response packet reports an error; data
// without stmt is returned when reading the parameter block fails.
func (mc *mysqlConn) Prepare(query string) (data [][]byte, stmt *mysqlStmt, err error) {
	mc.Lock()
	defer func() {
		mc.lastError = err
		mc.Unlock()
	}()

	if mc.netConn == nil {
		return nil, nil, mysql.ErrBadConn
	}

	// Send the prepare command.
	if err = mc.writeCommandPacketStr(mysql.ComStmtPrepare, query); err != nil {
		return nil, nil, err
	}

	stmt = &mysqlStmt{
		mysqlConn: mc,
	}

	// Read the prepare response header.
	head, columnCount, err := stmt.readPrepareResultPacket()
	data = append(make([][]byte, 0, columnCount*2+1), head)
	if err != nil {
		return data, stmt, err
	}

	// Parameter definition block.
	if stmt.paramCount > 0 {
		if data, err = mc.readUntilEOF(data); err != nil {
			return data, nil, err
		}
	}

	// Column definition block.
	if columnCount > 0 {
		data, err = mc.readUntilEOF(data)
	}
	return data, stmt, err
}

// expired reports whether more than timeout has elapsed since mc.pushedAt,
// measured against nowFunc. A zero or negative timeout never expires.
func (mc *mysqlConn) expired(timeout time.Duration) bool {
	if timeout > 0 {
		deadline := mc.pushedAt.Add(timeout)
		return deadline.Before(nowFunc())
	}
	return false
}

// UseDB brings the connection to database name and character set collate.
//
// Nothing is sent when both already match the values recorded on the
// connection. An empty name keeps the current database; an empty collate
// selects mysql.DefaultCharset. The database is switched with a ComInitDB
// command and the character set with a "SET NAMES" query, each issued only
// when the value differs. mc.dbname and mc.collation are updated only after
// the corresponding command succeeds.
func (mc *mysqlConn) UseDB(name string, collate string) (err error) {
	if name == mc.dbname && collate == mc.collation {
		return nil
	}

	if name == "" {
		name = mc.dbname
	}
	// Fall back to the default character set.
	if collate == "" {
		collate = mysql.DefaultCharset
	}

	cmd := bytes.NewBuffer(make([]byte, 0, 64))

	// Switch database.
	if name != mc.dbname {
		cmd.WriteByte(mysql.ComInitDB)
		cmd.WriteString(name)
		if _, err = mc.Exec(cmd.Bytes()); err != nil {
			return err
		}
		mc.dbname = name
	}

	// Switch character set.
	if collate != mc.collation {
		cmd.Reset() // the buffer may hold the ComInitDB command
		cmd.WriteByte(mysql.ComQuery)
		fmt.Fprintf(cmd, "SET NAMES %s;", collate)
		if _, err = mc.Exec(cmd.Bytes()); err != nil {
			return err
		}
		mc.collation = collate
	}

	return nil
}

// Exec sends one command to the server and returns the first response
// packet.
//
// data[0] is the command byte and data[1:] its argument. Only the first
// response packet is read: a command that produces a result set should go
// through Query, which also consumes the column and row packets.
//
// It returns mysql.ErrBadConn when the connection has been closed and
// mysql.ErrMalformPacket when data is empty. A server ERR packet is returned
// as the error produced by handleErrorPacket. The resulting error is also
// stored in mc.lastError.
func (mc *mysqlConn) Exec(data []byte) (result []byte, err error) {
	mc.Lock()
	defer func() {
		mc.lastError = err
		mc.Unlock()
	}()

	// Same closed-connection guard as Prepare.
	if mc.netConn == nil {
		return nil, mysql.ErrBadConn
	}
	// Without a command byte there is nothing to send.
	if len(data) == 0 {
		return nil, mysql.ErrMalformPacket
	}

	if err = mc.writeCommandPacketStr(data[0], string(data[1:])); err != nil {
		return nil, err
	}

	if result, _, err = mc.readResultSetHeaderPacket(); err != nil {
		return nil, err
	}
	return result, nil
}

// Query sends one command to the server and returns every packet of the
// response.
//
// data[0] is the command byte and data[1:] its argument. The result starts
// with the header packet; when the header announces columns, the column
// definition packets and then the row packets follow, each group including
// its terminating EOF packet.
//
// It returns mysql.ErrBadConn when the connection has been closed and
// mysql.ErrMalformPacket when data is empty. On any error the partial result
// is discarded and nil is returned. The resulting error is also stored in
// mc.lastError.
func (mc *mysqlConn) Query(data []byte) (result [][]byte, err error) {
	mc.Lock()
	defer func() {
		mc.lastError = err
		mc.Unlock()
	}()

	// Same closed-connection guard as Prepare.
	if mc.netConn == nil {
		return nil, mysql.ErrBadConn
	}
	// Without a command byte there is nothing to send.
	if len(data) == 0 {
		return nil, mysql.ErrMalformPacket
	}

	if err = mc.writeCommandPacketStr(data[0], string(data[1:])); err != nil {
		return nil, err
	}

	header, columnCount, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}
	result = make([][]byte, 0, columnCount*2+1)
	result = append(result, header)

	if columnCount > 0 {
		// Column definitions.
		if result, err = mc.readResultSetPacket(result, columnCount); err != nil {
			return nil, err
		}
		// Rows.
		if result, err = mc.readUntilEOF(result); err != nil {
			return nil, err
		}
	}
	return result, nil
}

// ping sends a ComPing command and reads one reply packet.
//
// Call it only while the connection is idle; per the original note, pinging
// a connection that is in use may find its buffer busy.
//
// It returns mysql.ErrBadConn when the connection has been closed. The
// resulting error is stored in mc.lastError. The lock is released through
// defer, like Exec, Query and Prepare, so it is not left held if a lower
// layer panics.
func (mc *mysqlConn) ping() (err error) {
	mc.Lock()
	defer func() {
		mc.lastError = err
		mc.Unlock()
	}()

	if mc.netConn == nil {
		return mysql.ErrBadConn
	}
	if err = mc.writeCommandPacket(mysql.ComPing); err != nil {
		return err
	}
	_, err = mc.io.readPacket()
	return err
}

// readResultSetPacket reads packets until an EOF packet arrives, appending
// every packet (the EOF included) to res.
//
// count is the number of packets expected before the EOF. If a different
// number was seen, res is returned together with a "column count mismatch"
// error. A read error discards res and returns nil.
func (mc *mysqlConn) readResultSetPacket(res [][]byte, count int) ([][]byte, error) {
	seen := 0
	for {
		pkt, err := mc.io.readPacket()
		if err != nil {
			return nil, err
		}
		res = append(res, pkt)

		// An EOF packet is 0xfe-led and either 5 bytes or 1 byte long.
		isEOF := pkt[0] == mysql.EOFHeader && (len(pkt) == 5 || len(pkt) == 1)
		if !isEOF {
			seen++
			continue
		}

		if seen != count {
			return res, fmt.Errorf("column count mismatch n:%d len:%d", count, seen)
		}
		return res, nil
	}
}

// readUntilEOF reads packets until an EOF packet, an ERR packet or a read
// error, appending each packet read to res, and returns the extended slice.
//
// A 5-byte EOF packet also updates mc.status from its status field. An ERR
// packet ends the stream: it is appended to res and converted to an error by
// handleErrorPacket, rather than being treated as a row and followed by
// another read. An empty packet yields mysql.ErrMalformPacket.
//
// As before, whatever readPacket returned is appended to res even when the
// read itself failed, so the shape of res on a read error is unchanged.
func (mc *mysqlConn) readUntilEOF(res [][]byte) ([][]byte, error) {
	for {
		data, err := mc.io.readPacket()
		res = append(res, data)
		if err != nil {
			return res, err
		}
		if len(data) == 0 {
			return res, mysql.ErrMalformPacket
		}

		switch data[0] {
		case mysql.ErrHeader:
			return res, mc.handleErrorPacket(data)
		case mysql.EOFHeader:
			if len(data) == 5 {
				mc.status = readStatus(data[3:])
			}
			return res, nil
		}
	}
}

// readResultSetHeaderPacket reads the first packet of a command response and
// classifies it.
//
// It returns the raw packet, the column count (non-zero only for a result
// set header) and an error. OK and ERR packets are delegated to
// handleOkPacket and handleErrorPacket; a LOCAL INFILE request is rejected.
// Any other packet must consist of exactly one length-encoded integer, the
// column count; otherwise mysql.ErrMalformPacket is returned.
//
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() ([]byte, int, error) {
	pkt, err := mc.io.readPacket()
	if err != nil {
		return pkt, 0, err
	}

	switch pkt[0] {
	case mysql.OKHeader:
		return pkt, 0, mc.handleOkPacket(pkt)
	case mysql.ErrHeader:
		return pkt, 0, mc.handleErrorPacket(pkt)
	case mysql.LocalInFileHeader:
		return pkt, 0, fmt.Errorf("not support")
	}

	// Column count: the integer must span the whole packet.
	count, _, consumed := readLengthEncodedInteger(pkt)
	if consumed != len(pkt) {
		logs.Debug(pkt, count, consumed)
		return pkt, 0, mysql.ErrMalformPacket
	}
	return pkt, int(count), nil
}

// handleErrorPacket converts a server ERR packet into a *mysql.SQLError.
//
// Layout: 0xff marker [1 byte], error number [2 bytes, little endian], an
// optional SQL state ('#' followed by 5 bytes), then the error message.
// The State of the returned error is looked up from the error number in
// mysql.MySQLState; the state bytes in the packet are skipped.
//
// mysql.ErrMalformPacket is returned when the packet does not start with the
// ERR marker or is too short for the fields it announces, instead of
// indexing past the end of data.
//
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
	// Marker plus the 2-byte error number are mandatory.
	if len(data) < 3 || data[0] != mysql.ErrHeader {
		return mysql.ErrMalformPacket
	}

	// Error Number [16 bit uint]
	errno := binary.LittleEndian.Uint16(data[1:3])

	pos := 3

	// SQL State [optional: # + 5bytes string]
	if len(data) > 3 && data[3] == 0x23 {
		if len(data) < 9 {
			// '#' announced a state but the packet ends before it does.
			return mysql.ErrMalformPacket
		}
		pos = 9
	}

	// Error Message [string]
	return &mysql.SQLError{
		Code:    errno,
		State:   mysql.MySQLState[errno],
		Message: string(data[pos:]),
	}
}

// readInitPacket reads the server's initial handshake packet and returns the
// password cipher (auth-plugin data) contained in it.
//
// Side effect: mc.flags is set to the lower 16 bits of the server capability
// flags. The call fails when the packet is an ERR packet (converted by
// handleErrorPacket), when the protocol version is below MinProtocolVersion,
// or when the server does not advertise ClientProtocol41
// (dmysql.ErrOldProtocol).
//
// The returned slice is a fresh copy, 20 bytes long when the packet carries
// the extended fields after the capability flags and 8 bytes long otherwise.
//
// NOTE(review): none of the offsets below are checked against len(data); a
// truncated handshake packet would panic on slicing.
//
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
	data, err := mc.io.readPacket()
	if err != nil {
		return nil, err
	}

	if data[0] == mysql.ErrHeader {
		return nil, mc.handleErrorPacket(data)
	}

	// protocol version [1 byte]
	if data[0] < MinProtocolVersion {
		return nil, fmt.Errorf(
			"unsupported protocol version %d. Version %d or higher is required",
			data[0],
			MinProtocolVersion,
		)
	}

	// server version [null terminated string]
	// connection id [4 bytes]
	// pos ends up just past the connection id: 1 (version byte) + length of
	// the server version string + 1 (its NUL) + 4 (connection id).
	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4

	// first part of the password cipher [8 bytes]
	cipher := data[pos : pos+8]

	// (filler) always 0x00 [1 byte]
	pos += 8 + 1

	// capability flags (lower 2 bytes) [2 bytes]
	mc.flags = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
	if mc.flags&mysql.ClientProtocol41 == 0 {
		return nil, dmysql.ErrOldProtocol
	}
	// if mc.flags&mysql.ClientSSL == 0 && mc.cfg.tls != nil {
	// 	return nil, mysql.ErrNoTLS
	// }
	pos += 2

	if len(data) > pos {
		// character set [1 byte]
		// status flags [2 bytes]
		// capability flags (upper 2 bytes) [2 bytes]
		// length of auth-plugin-data [1 byte]
		// reserved (all [00]) [10 bytes]
		pos += 1 + 2 + 2 + 1 + 10

		// second part of the password cipher [minimum 13 bytes],
		// where len=MAX(13, length of auth-plugin-data - 8)
		//
		// The web documentation is ambiguous about the length. However,
		// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
		// the 13th byte is "\0 byte, terminating the second part of
		// a scramble". So the second part of the password cipher is
		// a NULL terminated string that's at least 13 bytes with the
		// last byte being NULL.
		//
		// The official Python library uses the fixed length 12
		// which seems to work but technically could have a hidden bug.
		//
		// cipher is a sub-slice of data, so this append writes the 12 bytes
		// into data's own backing array right after the first 8 cipher bytes;
		// the result is copied out below before being returned.
		cipher = append(cipher, data[pos:pos+12]...)

		// TODO: Verify string termination
		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
		// \NUL otherwise
		//
		//if data[len(data)-1] == 0 {
		//	return
		//}
		//return ErrMalformPkt

		// make a memory safe copy of the cipher slice
		var b [20]byte
		copy(b[:], cipher)
		return b[:], nil
	}

	// make a memory safe copy of the cipher slice
	var b [8]byte
	copy(b[:], cipher)
	return b[:], nil
}

// writeAuthPacket builds the handshake response (client authentication)
// packet and writes it through mc.io.
//
// cipher is the password cipher obtained from the server handshake; it is
// combined with mc.cfg.Passwd by scramblePassword. The packet carries a
// fixed set of client capability flags (plus ClientLongFlag when the server
// advertised it), a zero max-packet-size, the ID of the default collation,
// the user name, the scrambled password and the plugin name
// "mysql_native_password". No database name is sent and TLS is not
// negotiated; the code for both is commented out below.
//
// It returns an error when the default collation is missing from
// mysql.CollationNames or when writing the packet fails.
//
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
	// Adjust client flags based on server support
	clientFlags := mysql.ClientProtocol41 |
		mysql.ClientSecureConnection |
		mysql.ClientLongPassword |
		mysql.ClientTransactions |
		//mysql.ClientLocalFiles |
		mysql.ClientPluginAuth |
		mysql.ClientMultiResults |
		mysql.ClientMultiStatements |
		mc.flags&mysql.ClientLongFlag

	// if mc.cfg.ClientFoundRows {
	// 	clientFlags |= mysql.ClientFoundRows
	// }

	// To enable TLS / SSL
	// if mc.cfg.tls != nil {
	// 	clientFlags |= ClientSSL
	// }

	// User Password
	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))

	// Payload size: flags (4) + max packet size (4) + charset (1) +
	// filler (23) + user + NUL (1) + scramble length byte (1) + scramble +
	// plugin name (21) + NUL (1).
	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1

	// To specify a db name
	// if n := len(mc.cfg.DBName); n > 0 {
	// 	clientFlags |= mysql.ClientConnectWithDB
	// 	pktLen += n + 1
	// }

	// Calculate packet length and get buffer with that size
	// The extra 4 bytes at data[0:4] are not written in this function.
	// NOTE(review): presumably reserved for the packet header that
	// writePacket fills in — confirm against packetIO.
	data := mc.buf.GetSize(pktLen + 4)
	// data := make([]byte, pktLen+4)

	// ClientFlags [32 bit]
	data[4] = byte(clientFlags)
	data[5] = byte(clientFlags >> 8)
	data[6] = byte(clientFlags >> 16)
	data[7] = byte(clientFlags >> 24)

	// MaxPacketSize [32 bit] (none)
	data[8] = 0x00
	data[9] = 0x00
	data[10] = 0x00
	data[11] = 0x00

	// Charset [1 byte]
	var found bool
	data[12], found = mysql.CollationNames[mysql.DefaultCollationName]
	if !found {
		// Note possibility for false negatives:
		// could be triggered  although the collation is valid if the
		// collations map does not contain entries the server supports.
		return errors.New("unknown collation")
	}

	// SSL Connection Request Packet
	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
	// if mc.cfg.tls != nil {
	// 	// Send TLS / SSL request packet
	// 	if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
	// 		return err
	// 	}

	// 	// Switch to TLS
	// 	tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
	// 	if err := tlsConn.Handshake(); err != nil {
	// 		return err
	// 	}
	// 	mc.netConn = tlsConn
	// 	mc.buf.nc = tlsConn
	// }

	// Filler [23 bytes] (all 0x00)
	pos := 13
	for ; pos < 13+23; pos++ {
		data[pos] = 0
	}

	// User [null terminated string]
	if len(mc.cfg.User) > 0 {
		pos += copy(data[pos:], mc.cfg.User)
	}
	data[pos] = 0x00
	pos++

	// ScrambleBuffer [length encoded integer]
	// One length byte followed by the scramble bytes themselves.
	data[pos] = byte(len(scrambleBuff))
	pos += 1 + copy(data[pos+1:], scrambleBuff)

	// Databasename [null terminated string]
	// if len(mc.cfg.DBName) > 0 {
	// 	pos += copy(data[pos:], mc.cfg.DBName)
	// 	data[pos] = 0x00
	// 	pos++
	// }
	// mc.dbname = mc.cfg.DBName

	// Assume native client during response
	pos += copy(data[pos:], "mysql_native_password")
	data[pos] = 0x00

	// Send Auth packet
	return mc.io.writePacket(data)
}

// readInitOK reads the server's reply to the authentication packet.
//
// It returns nil for an OK packet. A 0xfe-led packet is a request to switch
// to another auth plugin, none of which is supported here:
//   - a bare 0xfe byte, or plugin "mysql_old_password": dmysql.ErrOldPassword
//   - plugin "mysql_clear_password": dmysql.ErrCleartextPassword
//   - any other plugin: dmysql.ErrUnknownPlugin
//
// Any other packet is handed to handleErrorPacket. An empty packet, or a
// plugin name without its NUL terminator, yields mysql.ErrMalformPacket
// instead of slicing with an invalid index.
func (mc *mysqlConn) readInitOK() error {
	data, err := mc.io.readPacket()
	if err != nil {
		return err
	}
	if len(data) == 0 {
		return mysql.ErrMalformPacket
	}

	// packet indicator
	switch data[0] {
	case mysql.OKHeader:
		return nil

	case mysql.EOFHeader:
		if len(data) == 1 {
			return dmysql.ErrOldPassword
		}
		// The plugin name is NUL terminated; IndexByte returns -1 when the
		// terminator is missing.
		end := bytes.IndexByte(data, 0x00)
		if end < 0 {
			return mysql.ErrMalformPacket
		}
		switch string(data[1:end]) {
		case "mysql_old_password":
			// using old_passwords
			return dmysql.ErrOldPassword
		case "mysql_clear_password":
			// using clear text password
			return dmysql.ErrCleartextPassword
		default:
			return dmysql.ErrUnknownPlugin
		}

	default: // Error otherwise
		return mc.handleErrorPacket(data)
	}
}

// readStatus decodes the first two bytes of b as a little-endian 16-bit
// value. b must hold at least two bytes.
func readStatus(b []byte) uint16 {
	lo, hi := uint16(b[0]), uint16(b[1])
	return hi<<8 | lo
}
